package classification.weibo_feeds.ml

import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.util.control.NonFatal

/**
  * Created by peibin on 2017/3/20.
  */
object WeiboFeedsNaiveBayes {

  /**
    * Entry point: trains and evaluates a Naive Bayes sentiment classifier
    * over labelled Weibo feed text files.
    *
    * Pipeline: read raw lines -> parse into [[Feeds]] -> Ansj tokenisation
    * ([[Splitter]]) -> TF hashing -> IDF weighting -> 60/40 train/test split
    * -> NaiveBayes fit -> accuracy report.
    *
    * @param args unused command-line arguments
    */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("WeiboFeeds")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._


    // Each input line is "|**|"-separated: poster|**|feed|**|sentiment[|**|reTweeted].
    // Malformed lines become a sentinel Feeds(null, ...) and are filtered out below.
    val feeds = spark.read.textFile("""file:///Users/peibin/workspace/data-mining/带有转发和情感标签的微博数据/data/*.txt""")
      .map(values => {
        val tuples = values.split("""\|\*\*\|""")
        try {
          tuples.length match {
            case 3 => Feeds(tuples(0).toInt, tuples(1), tuples(2).toInt, null)
            case 4 => Feeds(tuples(0).toInt, tuples(1), tuples(2).toInt, tuples(3).toInt)
            // Explicit default instead of relying on a MatchError reaching the catch.
            case _ => Feeds(null, null, null, null)
          }
        } catch {
          // NonFatal only: number-format / parse failures mark the row invalid,
          // while fatal errors (OOM, interrupts) still propagate.
          case NonFatal(_) => Feeds(null, null, null, null)
        }
      }).filter(feeds => feeds.feed != null)


    feeds.cache()
    feeds.show(10)

    // Tokenise the raw feed text into a words column via the Ansj segmenter.
    val splitter = new Splitter().setInputCol("feed").setOutputCol("words")

    val wordsData = splitter.transform(feeds)

    // Term-frequency hashing; default feature dimension is used.
    val hashingTF = new HashingTF()
      .setInputCol("words").setOutputCol("rawFeatures")
    //.setNumFeatures(20)
    val featurizedData = hashingTF.transform(wordsData)

    // Re-weight raw term frequencies by inverse document frequency.
    val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
    val idfModel = idf.fit(featurizedData)

    val rescaledData = idfModel.transform(featurizedData)
    val Array(trainingData, testData) = rescaledData.randomSplit(Array(0.6, 0.4))

    // Train a NaiveBayes model on the sentiment label.
    val model = new NaiveBayes().setFeaturesCol("features").setLabelCol("sentiment")
      .fit(trainingData)


    // Score the held-out split.
    val predictions = model.transform(testData)

    // Manual accuracy: fraction of rows where the predicted class matches the label.
    val right = predictions.filter(r => r.getAs[Double]("prediction").toInt == r.getAs[Int]("sentiment")).count()
    val total = predictions.count()
    println(s"${right} / ${total} = ${1.0 * right / total}")


    // Cross-check with Spark's built-in multiclass accuracy metric.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("sentiment")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println("Test set accuracy = " + accuracy)

    spark.stop()

  }

  case class Feeds(poster: Integer, feed: String, sentiment: Integer, reTweetd: Integer);

  /**
    * Spark ML transformer that tokenises Chinese text with the Ansj segmenter.
    *
    * Maps a `String` input column to a `Seq[String]` of word surface forms,
    * suitable as input for [[org.apache.spark.ml.feature.HashingTF]].
    *
    * @param uid unique ML param identifier (auto-generated by the no-arg constructor)
    */
  class Splitter(override val uid: String)
    extends UnaryTransformer[String, Seq[String], Splitter] with DefaultParamsWritable {

    def this() = this(Identifiable.randomUID("tok"))

    // Explicit .asScala instead of the deprecated JavaConversions implicits:
    // Ansj returns a java.util.List of terms.
    override protected def createTransformFunc: (String) => Seq[String] =
      x => ToAnalysis.parse(x).getTerms.asScala.map(_.getRealName)

    // Output schema: array<string>, entries nullable.
    override protected def outputDataType: DataType = new ArrayType(StringType, true)

  }
}
